{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import math\n",
    "from math import sqrt\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from sklearn.datasets import load_digits\n",
    "from sklearn import datasets\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EPS = 1.e-7"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**DISCLAIMER**\n",
    "\n",
    "The presented code is not optimized, it serves an educational purpose. It is written for CPU, it uses only fully-connected networks and an extremely simplistic dataset. However, it contains all components that can help to understand how neural compression works, and it should be rather easy to extend it to more sophisticated models. This code could be run almost on any laptop/PC, and it takes a couple of minutes top to get the result."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, we go wild and use a dataset that is simpler than MNIST! We use a scipy dataset called Digits. It consists of ~1500 images of size 8x8, and each pixel can take values in $\\{0, 1, \\ldots, 16\\}$.\n",
    "\n",
    "The goal of using this dataset is that everyone can run it on a laptop, without any gpu etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Digits(Dataset):\n",
    "    \"\"\"Scikit-Learn Digits dataset.\"\"\"\n",
    "\n",
    "    def __init__(self, mode='train', transforms=None):\n",
    "        digits = load_digits()\n",
    "        if mode == 'train':\n",
    "            self.data = digits.data[:1000].astype(np.float32)\n",
    "        elif mode == 'val':\n",
    "            self.data = digits.data[1000:1350].astype(np.float32)\n",
    "        else:\n",
    "            self.data = digits.data[1350:].astype(np.float32)\n",
    "        \n",
    "        self.transforms = transforms\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        sample = self.data[idx]\n",
    "        if self.transforms:\n",
    "            sample = self.transforms(sample)\n",
    "        return sample"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Auxiliary code"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Auxiliary code for running some parts, e.g., Causal Convolution 1D for the autoregressive model (ARM)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Causal Convolution for ARM**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CausalConv1d(nn.Module):\n",
    "    \"\"\"\n",
    "    A causal 1D convolution.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels, out_channels, kernel_size, dilation, A=False, **kwargs):\n",
    "        super(CausalConv1d, self).__init__()\n",
    "\n",
    "        # attributes:\n",
    "        self.kernel_size = kernel_size\n",
    "        self.dilation = dilation\n",
    "        self.A = A\n",
    "        \n",
    "        self.padding = (kernel_size - 1) * dilation + A * 1\n",
    "\n",
    "        # module:\n",
    "        self.conv1d = torch.nn.Conv1d(in_channels, out_channels,\n",
    "                                      kernel_size, stride=1,\n",
    "                                      padding=0,\n",
    "                                      dilation=dilation,\n",
    "                                      **kwargs)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = torch.nn.functional.pad(x, (self.padding, 0))\n",
    "        conv1d_out = self.conv1d(x)\n",
    "        if self.A:\n",
    "            return conv1d_out[:, :, : -1]\n",
    "        else:\n",
    "            return conv1d_out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Neural Compression code"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Please see the blogpost for details."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Quantizer**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The quantizer is the crucial component of a neural compressor. It consists of a codebook, a vector of floats. It takes a real-valued input and replaces them with the closests values in the codebook.\n",
    "Please note that we use a real-valued codebook, however, in practice, we can implement it using integers. As a result, we use $K$ bits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Quantizer(nn.Module):\n",
    "    def __init__(self, input_dim, codebook_dim, temp=1.e7):\n",
    "        super(Quantizer, self).__init__()\n",
    "        \n",
    "        #temperature for softmax\n",
    "        self.temp = temp\n",
    "        \n",
    "        # dimensionality of the inputs and the codebook\n",
    "        self.input_dim = input_dim\n",
    "        self.codebook_dim = codebook_dim\n",
    "        \n",
    "        # codebook layer (a codebook)\n",
    "        # - we initialize it uniformly\n",
    "        # - we make it Parameter, namely, it is learnable\n",
    "        self.codebook = nn.Parameter(torch.FloatTensor(1, self.codebook_dim,).uniform_(-1/self.codebook_dim, 1/self.codebook_dim))\n",
    "    \n",
    "    # A function for codebook indices (a one-hot representation) to values in the codebook.\n",
    "    def indices2codebook(self, indices_onehot):\n",
    "        return torch.matmul(indices_onehot, self.codebook.t()).squeeze()\n",
    "    \n",
    "    # A function to change integers to a one-hot representation.\n",
    "    def indices_to_onehot(self, inputs_shape, indices):\n",
    "        indices_hard = torch.zeros(inputs_shape[0], inputs_shape[1], self.codebook_dim)\n",
    "        indices_hard.scatter_(2, indices, 1)\n",
    "    \n",
    "    # The forward function:\n",
    "    # - First, distances are calculated between input values and codebook values.\n",
    "    # - Second, indices (soft - differentiable, hard - non-differentiable) between the encoded values and the codebook values are calculated.\n",
    "    # - Third, the quantizer returns indices and quantized code (the output of the encoder).\n",
    "    # - Fourth, the decoder maps the quantized code to the obeservable space (i.e., it decodes the code back).\n",
    "    def forward(self, inputs):\n",
    "        # inputs - a matrix of floats, B x M\n",
    "        inputs_shape = inputs.shape\n",
    "        # repeat inputs\n",
    "        inputs_repeat = inputs.unsqueeze(2).repeat(1, 1, self.codebook_dim)\n",
    "        # calculate distances between input values and the codebook values\n",
    "        distances = torch.exp(-torch.sqrt(torch.pow(inputs_repeat - self.codebook.unsqueeze(1), 2)))\n",
    "        \n",
    "        # indices (hard, i.e., nondiff)\n",
    "        indices = torch.argmax(distances, dim=2).unsqueeze(2)\n",
    "        indices_hard = self.indices_to_onehot(inputs_shape=inputs_shape, indices=indices)\n",
    "        \n",
    "        # indices (soft, i.e., diff)\n",
    "        indices_soft = torch.softmax(self.temp * distances, -1)\n",
    "        \n",
    "        # quantized values: we use soft indices here because it allows backpropagation\n",
    "        quantized = self.indices2codebook(indices_onehot=indices_soft)\n",
    "        \n",
    "        return (indices_soft, indices_hard, quantized)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Encoder**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The encoder is simply a neural network that takes an image and outputs a corresponding code.\n",
    "class Encoder(nn.Module):\n",
    "    def __init__(self, encoder_net):\n",
    "        super(Encoder, self).__init__()\n",
    "\n",
    "        self.encoder = encoder_net\n",
    "\n",
    "    def encode(self, x):\n",
    "        h_e = self.encoder(x)\n",
    "        return h_e\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return self.encode(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Decoder**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The decoder is simply a neural network that takes a quantized code and returns an image.\n",
    "class Decoder(nn.Module):\n",
    "    def __init__(self, decoder_net):\n",
    "        super(Decoder, self).__init__()\n",
    "\n",
    "        self.decoder = decoder_net\n",
    "\n",
    "    def decode(self, z):\n",
    "        h_d = self.decoder(z)\n",
    "        return h_d\n",
    "\n",
    "    def forward(self, z, x=None):\n",
    "        x_rec = self.decode(z)\n",
    "        return x_rec"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Entropy Coding**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Entropy coding is the crucial step in the compression scheme. At this point, we have a quantized code that we want to transmit. In order to send the quantized code, which is typically represented by discerete (non-binary) symbols, it must be translated into a bitstream (a stream o bits).\n",
    "\n",
    "An entropy coder assigns a unique prefix-free code (e.g., unique binary codes like Huffman codes) to each unique symbol that occurs in the input. Two of the most common entropy encoding techniques are Huffman coding and arithmetic coding that require knowing the (estimates of) probabilities of the symbols.\n",
    "\n",
    "In the following code, we present a non-learnable, uniform distribution over symbols, a learnable, independent distributions over symbols (i.e., the product of categorical distributions) and an auto-regressive model for entropy coding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class UniformEntropyCoding(nn.Module):\n",
    "    def __init__(self, code_dim, codebook_dim):\n",
    "        super(UniformEntropyCoding, self).__init__()\n",
    "        self.code_dim = code_dim\n",
    "        self.codebook_dim = codebook_dim\n",
    "        \n",
    "        self.probs = torch.softmax(torch.ones(1, self.code_dim, self.codebook_dim), -1)\n",
    "    \n",
    "    def sample(self, quantizer=None, B=10):\n",
    "        code = torch.zeros(B, self.code_dim, self.codebook_dim)\n",
    "        for b in range(B):\n",
    "            indx = torch.multinomial(torch.softmax(self.probs, -1).squeeze(0), 1).squeeze()\n",
    "            for i in range(self.code_dim):\n",
    "                code[b,i,indx[i]] = 1\n",
    "        \n",
    "        code = quantizer.indices2codebook(code)\n",
    "        return code\n",
    "    \n",
    "    def forward(self, z, x=None):\n",
    "        p = torch.clamp(self.probs, EPS, 1. - EPS)\n",
    "        return -torch.sum(z * torch.log(p), 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class IndependentEntropyCoding(nn.Module):\n",
    "    def __init__(self, code_dim, codebook_dim):\n",
    "        super(IndependentEntropyCoding, self).__init__()\n",
    "        self.code_dim = code_dim\n",
    "        self.codebook_dim = codebook_dim\n",
    "        \n",
    "        self.probs = nn.Parameter(torch.ones(1, self.code_dim, self.codebook_dim))\n",
    "    \n",
    "    def sample(self, quantizer=None, B=10):\n",
    "        code = torch.zeros(B, self.code_dim, self.codebook_dim)\n",
    "        for b in range(B):\n",
    "            indx = torch.multinomial(torch.softmax(self.probs, -1).squeeze(0), 1).squeeze()\n",
    "            for i in range(self.code_dim):\n",
    "                code[b,i,indx[i]] = 1\n",
    "        \n",
    "        code = quantizer.indices2codebook(code)\n",
    "        return code\n",
    "    \n",
    "    def forward(self, z, x=None):\n",
    "        p = torch.clamp(torch.softmax(self.probs, -1), EPS, 1. - EPS)\n",
    "        return -torch.sum(z * torch.log(p), 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ARMEntropyCoding(nn.Module):\n",
    "    def __init__(self, code_dim, codebook_dim, arm_net):\n",
    "        super(ARMEntropyCoding, self).__init__()\n",
    "        self.code_dim = code_dim\n",
    "        self.codebook_dim = codebook_dim\n",
    "        self.arm_net = arm_net # it takes B x 1 x code_dim and outputs B x codebook_dim x code_dim\n",
    "    \n",
    "    def f(self, x):\n",
    "        h = self.arm_net(x.unsqueeze(1))\n",
    "        h = h.permute(0, 2, 1)\n",
    "        p = torch.softmax(h, 2)\n",
    "        \n",
    "        return p\n",
    "    \n",
    "    def sample(self, quantizer=None, B=10):\n",
    "        x_new = torch.zeros((B, self.code_dim))\n",
    "        \n",
    "        for d in range(self.code_dim):\n",
    "            p = self.f(x_new)\n",
    "            indx_d = torch.multinomial(p[:, d, :], num_samples=1)\n",
    "            codebook_value = quantizer.codebook[0, indx_d].squeeze()\n",
    "            x_new[:, d] = codebook_value\n",
    "        \n",
    "        return x_new\n",
    "\n",
    "    def forward(self, z, x):\n",
    "        p = self.f(x)\n",
    "        return -torch.sum(z * torch.log(p), 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Full Neural Compressor**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NeuralCompressor(nn.Module):\n",
    "    def __init__(self, encoder, decoder, entropy_coding, quantizer, beta=1., detaching=False):\n",
    "        super(NeuralCompressor, self).__init__()\n",
    "\n",
    "        print('VAE by JT.')\n",
    "            \n",
    "        # we \n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "        self.entropy_coding = entropy_coding\n",
    "        self.quantizer = quantizer\n",
    "        \n",
    "        # beta determines how strongly we focus on compression against reconstruction quality\n",
    "        self.beta = beta\n",
    "        \n",
    "        # We can detach inputs to the rate, then we learn rate and distortion separately\n",
    "        self.detaching = detaching\n",
    "\n",
    "    def forward(self, x, reduction='avg'):\n",
    "        # encoding\n",
    "        #-non-quantized values\n",
    "        z = self.encoder(x)\n",
    "        #-quantizing\n",
    "        quantizer_out = self.quantizer(z)\n",
    "        \n",
    "        # decoding\n",
    "        x_rec = self.decoder(quantizer_out[2])\n",
    "        \n",
    "        # Distortion (e.g., MSE)\n",
    "        Distortion = torch.mean(torch.pow(x - x_rec, 2), 1)\n",
    "        \n",
    "        # Rate: we use the entropy coding here\n",
    "        Rate = torch.mean(self.entropy_coding(quantizer_out[0], quantizer_out[2]), 1)\n",
    "        \n",
    "        # Objective\n",
    "        objective = Distortion + self.beta * Rate\n",
    "        \n",
    "        if reduction == 'sum':\n",
    "            return objective.sum(), Distortion.sum(), Rate.sum()\n",
    "        else:\n",
    "            return objective.mean(), Distortion.mean(), Rate.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Auxiliary functions: training, evaluation, plotting"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It's rather self-explanatory, isn't it?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluation(test_loader, name=None, model_best=None, epoch=None):\n",
    "    # EVALUATION\n",
    "    if model_best is None:\n",
    "        # load best performing model\n",
    "        model_best = torch.load(name + '.model')\n",
    "\n",
    "    model_best.eval()\n",
    "    loss = 0.\n",
    "    distortion = 0.\n",
    "    rate = 0.\n",
    "    N = 0.\n",
    "    for indx_batch, test_batch in enumerate(test_loader):\n",
    "        loss_t, distortion_t, rate_t = model_best.forward(test_batch, reduction='sum')\n",
    "        loss = loss + loss_t.item()\n",
    "        distortion = distortion + distortion_t.item()\n",
    "        rate = rate + rate_t.item()\n",
    "        N = N + test_batch.shape[0]\n",
    "    loss = loss / N\n",
    "    distortion = distortion/N\n",
    "    rate = rate / N\n",
    "\n",
    "    if epoch is None:\n",
    "        print(f'FINAL LOSS: objective={loss} (distortion={distortion}, rate={rate})')\n",
    "    else:\n",
    "        print(f'Epoch: {epoch}, objective val={loss} (distortion={distortion}, rate={rate})')\n",
    "\n",
    "    return loss, distortion, rate\n",
    "\n",
    "def plot_curve(name, nll_val, metric_name='loss'):\n",
    "    plt.plot(np.arange(len(nll_val)), nll_val, linewidth='3')\n",
    "    plt.xlabel('epochs')\n",
    "    plt.ylabel(metric_name)\n",
    "    plt.savefig(name + metric_name + '_val_curve.pdf', bbox_inches='tight')\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def training(name, max_patience, num_epochs, model, optimizer, training_loader, val_loader):\n",
    "    objective_loss_val = []\n",
    "    objective_distortion_val = []\n",
    "    objective_rate_val = []\n",
    "    loss_best = 1000.\n",
    "    patience = 0\n",
    "\n",
    "    # Main loop\n",
    "    for e in range(num_epochs):\n",
    "        # TRAINING\n",
    "        model.train()\n",
    "        for indx_batch, batch in enumerate(training_loader):\n",
    "            if hasattr(model, 'dequantization'):\n",
    "                if model.dequantization:\n",
    "                    batch = batch + torch.rand(batch.shape)\n",
    "            loss, _, _ = model.forward(batch)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward(retain_graph=True)\n",
    "            optimizer.step()\n",
    "\n",
    "        # Validation\n",
    "        loss_val, distortion_val, rate_val = evaluation(val_loader, model_best=model, epoch=e)\n",
    "        objective_loss_val.append(loss_val)  # save for plotting\n",
    "        objective_distortion_val.append(distortion_val)  # save for plotting\n",
    "        objective_rate_val.append(rate_val)  # save for plotting\n",
    "\n",
    "        if e == 0:\n",
    "            print('saved!')\n",
    "            torch.save(model, name + '.model')\n",
    "            loss_best = loss_val\n",
    "        else:\n",
    "            if loss_val < loss_best:\n",
    "                print('saved!')\n",
    "                torch.save(model, name + '.model')\n",
    "                loss_best = loss_val\n",
    "                patience = 0\n",
    "            else:\n",
    "                patience = patience + 1\n",
    "\n",
    "        if patience > max_patience:\n",
    "            break\n",
    "\n",
    "    objective_loss_val = np.asarray(objective_loss_val)\n",
    "    objective_distortion_val = np.asarray(objective_distortion_val)\n",
    "    objective_rate_val = np.asarray(objective_rate_val)\n",
    "\n",
    "    return objective_loss_val, objective_distortion_val, objective_rate_val"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize dataloaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = Digits(mode='train')\n",
    "val_data = Digits(mode='val')\n",
    "test_data = Digits(mode='test')\n",
    "\n",
    "training_loader = DataLoader(train_data, batch_size=64, shuffle=True)\n",
    "val_loader = DataLoader(val_data, batch_size=64, shuffle=False)\n",
    "test_loader = DataLoader(test_data, batch_size=64, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Hyperparams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "entropy_coding_type = 'arm' # arm or indp or uniform\n",
    "D = 64   # input dimension\n",
    "C = 16  # code length\n",
    "E = 8 # codebook size (i.e., the number of quantized values)\n",
    "M = 256  # the number of neurons\n",
    "M_kernels = 32 # the number of kernels in causal conv1d layers\n",
    "\n",
    "# beta: how much we weight rate\n",
    "if entropy_coding_type == 'uniform':\n",
    "    beta = 0. \n",
    "else:\n",
    "    beta = 1.\n",
    "\n",
    "lr = 1e-3 # learning rate\n",
    "num_epochs = 1000 # max. number of epochs\n",
    "max_patience = 50 # an early stopping is used, if training doesn't improve for longer than 20 epochs, it is stopped"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result_dir = 'results/'\n",
    "if not(os.path.exists(result_dir)):\n",
    "    os.mkdir(result_dir)\n",
    "name = 'neural_compressor_' + entropy_coding_type + '_C_' + str(C) + '_E_' + str(E)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize Neural Compressor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ENCODER\n",
    "encoder_net = nn.Sequential(nn.Linear(D, M*2), nn.BatchNorm1d(M*2), nn.ReLU(),\n",
    "                            nn.Linear(M*2, M), nn.BatchNorm1d(M), nn.ReLU(),\n",
    "                            nn.Linear(M, M//2), nn.BatchNorm1d(M//2), nn.ReLU(),\n",
    "                            nn.Linear(M//2, C))\n",
    "\n",
    "encoder = Encoder(encoder_net=encoder_net)\n",
    "\n",
    "# DECODER\n",
    "decoder_net = nn.Sequential(nn.Linear(C, M//2), nn.BatchNorm1d(M//2), nn.ReLU(),\n",
    "                            nn.Linear(M//2, M), nn.BatchNorm1d(M), nn.ReLU(),\n",
    "                            nn.Linear(M, M*2), nn.BatchNorm1d(M*2), nn.ReLU(),\n",
    "                            nn.Linear(M*2, D))\n",
    "\n",
    "decoder = Decoder(decoder_net=decoder_net)\n",
    "\n",
    "# QUANTIZER\n",
    "quantizer = Quantizer(input_dim=C, codebook_dim=E)\n",
    "\n",
    "# ENTROPY CODING\n",
    "if entropy_coding_type == 'uniform':\n",
    "    entropy_coding = UniformEntropyCoding(code_dim=C, codebook_dim=E)\n",
    "    \n",
    "elif entropy_coding_type == 'indp':\n",
    "    entropy_coding = IndependentEntropyCoding(code_dim=C, codebook_dim=E)\n",
    "\n",
    "elif entropy_coding_type == 'arm':\n",
    "    kernel = 4\n",
    "    arm_net = nn.Sequential(\n",
    "        CausalConv1d(in_channels=1, out_channels=M_kernels, dilation=1, kernel_size=kernel, A=True, bias=True),\n",
    "        nn.LeakyReLU(),\n",
    "        CausalConv1d(in_channels=M_kernels, out_channels=M_kernels, dilation=1, kernel_size=kernel, A=False, bias=True),\n",
    "        nn.LeakyReLU(),\n",
    "        CausalConv1d(in_channels=M_kernels, out_channels=E, dilation=1, kernel_size=kernel, A=False, bias=True))\n",
    "\n",
    "    entropy_coding = ARMEntropyCoding(code_dim=C, codebook_dim=E, arm_net=arm_net)\n",
    "\n",
    "# MODEL\n",
    "model = NeuralCompressor(encoder=encoder, decoder=decoder, entropy_coding=entropy_coding, quantizer=quantizer, beta=beta)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let's play! Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# OPTIMIZER\n",
    "optimizer = torch.optim.Adamax([p for p in model.parameters() if p.requires_grad == True], lr=lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Training procedure\n",
    "objective_loss_val, objective_distortion_val, objective_rate_val = training(name=result_dir + name, max_patience=max_patience, num_epochs=num_epochs, model=model, \n",
    "                   optimizer=optimizer,\n",
    "                   training_loader=training_loader, val_loader=val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_loss, test_distortion, test_rate = evaluation(name=result_dir + name, test_loader=test_loader)\n",
    "f = open(result_dir + name + '_test_loss.txt', \"w\")\n",
    "f.write(str(test_loss) + ', ' + str(test_distortion) + ', ' + str(test_rate))\n",
    "f.close()\n",
    "\n",
    "plot_curve(result_dir + name + '_objective_', objective_loss_val, metric_name='objective')\n",
    "plot_curve(result_dir + name + '_distortion_', objective_distortion_val, metric_name='distortion')\n",
    "plot_curve(result_dir + name + '_rate_', objective_rate_val, metric_name='rate')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Qualitative inspection"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here, we visualize samples and reconstructions. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We specifies ids of images from the test set.\n",
    "IMG_IDs = [110, 120, 130, 140]\n",
    "\n",
    "# samples\n",
    "z_sampled = model.entropy_coding.sample(quantizer=model.quantizer, B=9)\n",
    "x_sampled = model.decoder(z_sampled)\n",
    "\n",
    "# reconstructions\n",
    "x_real = torch.from_numpy(test_data.__getitem__(IMG_IDs))\n",
    "x_rec = model.decoder(model.quantizer(model.encoder(x_real))[-1])\n",
    "\n",
    "# plotting\n",
    "fig, axs = plt.subplots(4, 3, figsize=(6, 8))\n",
    "i = 0\n",
    "for i in range(len(IMG_IDs)):\n",
    "    axs[i,0].imshow(x_real[i].reshape(8,8).detach().numpy())\n",
    "    axs[i,0].set_title('original')\n",
    "    axs[i,0].axis('off')\n",
    "    \n",
    "    axs[i,1].imshow(x_rec[i].reshape(8,8).detach().numpy())\n",
    "    axs[i,1].set_title('reconstruction')\n",
    "    axs[i,1].axis('off')\n",
    "    \n",
    "    axs[i,2].imshow(x_sampled[i].squeeze().reshape(8,8).detach().numpy())\n",
    "    axs[i,2].set_title('sample')\n",
    "    axs[i,2].axis('off')\n",
    "\n",
    "plt.savefig(result_dir + name + 'recon_sample.pdf', bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
